metal : optimize multi-sequence FA vec kernel #13493
@JohannesGaessler I am wondering if we can apply the same optimization to the CUDA FA kernels. I don't currently have a suitable CUDA machine that I can use for development, so I am just sharing some thoughts in case you have time to try them out.

The idea is that the FA-vec kernel can directly skip blocks for which the KQ mask is full of -INF. For single-sequence generations this does not happen, so we won't see any improvement there, but it should significantly improve multi-sequence generations and the upcoming SWA models. I took a quick look at the code and have the following suggestion:

```diff
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index d96e39212..171652dc8 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -175,6 +175,22 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:

+        // skip blocks for which the KQ mask is all -INF
+        half mask_max = -INF_HALF; // placeholder for the half -INF constant
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+                const int i_KQ = i_KQ_0 + threadIdx.y;
+                // any non -INF value in the block means it is not fully masked
+                mask_max = max(mask_max, maskh[j*ne11 + k_VKQ_0 + i_KQ]);
+            }
+        }
+
+        mask_max = warp_reduce_max(mask_max); // TODO: not sure if this is the right helper - the intent is simply to warp-reduce the variable
+        if (mask_max == -INF_HALF) {
+            continue;
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
```

It's not tested at all - I mainly hope to illustrate the idea. There might be a better way to do it.

Edit: Btw, it might be worth trying to force the `cols_per_block = 1` path:

```diff
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index d96e39212..3541b63ec 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -335,7 +351,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

-    if (Q->ne[1] == 1) {
+    if (true) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
```

The same logic should be possible to apply to the non-vec FA kernel, but this is less important for now. If you could give this a try and run the benchmarks, it would be much appreciated.
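For reference, the TODO in the diff above is essentially asking for a reduction of `mask_max` across the whole thread block, not just one warp. Below is a self-contained toy sketch of how such a block-wide "is this KV slice fully masked?" check could look (the names `block_fully_masked`/`toy_fa_loop`, the block shape, and the float mask are assumptions for illustration only - this is not the llama.cpp kernel):

```cuda
// Sketch only (assumed names and sizes, not llama.cpp code): skip KV slices whose mask is entirely -INF.
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define N_WARPS    4                   // assumed number of warps per thread block
#define KV_STEP  (WARP_SIZE*N_WARPS)   // KV positions checked per loop iteration

// Returns true (for the whole thread block) if every mask value in [k0, k0 + KV_STEP) is -INF.
__device__ bool block_fully_masked(const float * mask, int k0, int n_kv) {
    const int tid = threadIdx.y*WARP_SIZE + threadIdx.x;
    const int i   = k0 + tid;

    float v = (i < n_kv) ? mask[i] : -INFINITY;

    // reduce within the warp
    for (int offset = 16; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset));
    }

    // combine the per-warp maxima across warps (this is the part the TODO is about)
    __shared__ float warp_max[N_WARPS];
    if (threadIdx.x == 0) {
        warp_max[threadIdx.y] = v;
    }
    __syncthreads();

    float block_max = -INFINITY;
    for (int w = 0; w < N_WARPS; ++w) {
        block_max = fmaxf(block_max, warp_max[w]);
    }
    __syncthreads(); // warp_max can be reused on the next iteration

    return block_max == -INFINITY;
}

// Toy stand-in for the FA-vec main loop: counts how many KV positions are actually visited.
__global__ void toy_fa_loop(const float * mask, int n_kv, int * n_visited) {
    for (int k0 = 0; k0 < n_kv; k0 += KV_STEP) {
        if (block_fully_masked(mask, k0, n_kv)) {
            continue; // skip the whole KV slice - no K/V loads, no KQ/VKQ computation
        }
        if (threadIdx.x == 0 && threadIdx.y == 0) {
            atomicAdd(n_visited, KV_STEP);
        }
    }
}

int main() {
    const int n_kv = 8*KV_STEP;
    std::vector<float> mask(n_kv, -INFINITY);
    for (int i = 0; i < KV_STEP; ++i) {
        mask[i] = 0.0f; // only the first slice is visible to this sequence
    }

    float * d_mask; int * d_n;
    cudaMalloc(&d_mask, n_kv*sizeof(float));
    cudaMalloc(&d_n, sizeof(int));
    cudaMemcpy(d_mask, mask.data(), n_kv*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_n, 0, sizeof(int));

    toy_fa_loop<<<1, dim3(WARP_SIZE, N_WARPS)>>>(d_mask, n_kv, d_n);

    int n_visited = 0;
    cudaMemcpy(&n_visited, d_n, sizeof(int), cudaMemcpyDeviceToHost);
    printf("KV positions visited: %d / %d\n", n_visited, n_kv); // expect 128 / 1024

    cudaFree(d_mask); cudaFree(d_n);
    return 0;
}
```

The key point is the `continue`: a fully masked slice only costs the mask reads, with no K/V loads and no KQ/VKQ math.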
For token generation the bottleneck is I/O. Tensor cores work on matrix fragments of at least 8 rows/columns. Without GQA, for <= 8 tokens each value from the KV cache is loaded only once. So for <= 8 parallel sequences you can basically never skip any KV slices unless you use an implementation without tensor cores (the vec kernels do not use them). The vec kernels only work for a single token; for >1 tokens they would need to use KV values for >1 Q values, and you would quickly run into the same problem where each KV value is basically always used by at least one sequence and can never be skipped completely. At most you can skip some of the compute, but as long as you're I/O bound all of the compute pipelines are going to be underutilized anyway.

However, if a model uses GQA then you can get a much higher arithmetic intensity for each KV value, since you can apply it to multiple Q values for each sequence. In such a scenario an optimization for potentially skipping a KV slice could work because you're much less I/O bound. As it is, the CUDA kernels entangle up to 64 Q columns though, so for this optimization to work correctly the selection logic for the number of parallel columns needs to know that the tokens belong to different sequences. So the bottom line is that for CUDA this optimization will only make sense for at least …
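As a rough illustration of the arithmetic-intensity argument above, here is a back-of-the-envelope sketch (the head counts and the FP16 KV cache are assumptions; the numbers are illustrative, not measurements): each loaded K/V element of a sequence is reused roughly `n_tokens_per_seq * n_head/n_head_kv` times, so without GQA and a single token per sequence it is used exactly once.

```cuda
// Back-of-the-envelope sketch (illustrative, not a measurement): how often each loaded
// KV value can be reused per sequence, with and without GQA.
#include <cstdio>

int main() {
    const int n_head    = 32;  // assumed number of Q heads
    const int n_head_kv = 8;   // assumed number of KV heads (GQA ratio 4); set to 32 for "no GQA"

    for (int n_tokens_per_seq = 1; n_tokens_per_seq <= 8; n_tokens_per_seq *= 2) {
        // Q columns of one sequence that share a single KV head:
        const int q_cols_per_kv = n_tokens_per_seq*(n_head/n_head_kv);

        // ~2 FLOPs (mul + add) per Q column per KV element vs. 2 bytes per FP16 KV element
        const double flops_per_byte = 2.0*q_cols_per_kv/2.0;

        printf("tokens/seq = %d -> each KV value is used for %2d Q columns (~%.0f FLOPs/byte)\n",
               n_tokens_per_seq, q_cols_per_kv, flops_per_byte);
    }
    return 0;
}
```

With `n_head_kv == n_head` (no GQA) and one token per sequence the reuse count is 1, which is the I/O-bound case described above.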
I have a machine with 6 RTX 4090s in my basement; I can give you access if it would be useful for your work.
Just so I don't misunderstand anything: multiple parallel sequences result in an increase in …?

Also, I forgot: for SWA specifically, modifying the kernels to skip fully masked-out KV slices would work without issue.
The CUDA kernel in …
I should probably have been more explicit about this: you do not gain any advantage in terms of I/O if you make this change; at most you gain an advantage in terms of compute. The vec kernels would first need to receive the same GQA optimizations as the kernel in …
Generally yes, but there is also another use case. For example, in #10860, using the provided Python runner script in sequential mode would have …

The script there also allows running in parallel mode - in that case, yes, …
No worries, I was looking for a quick test. Just to understand: is the …?
For Ampere that kernel should only be used for batch size 1 and only if not using GQA, but that is only because the mma implementation turned out to be faster due to the GQA optimization.
Ok thanks, I think I understand better now.
Could you clarify what "fractional tiles" means in this case?

To give some additional context: the main drawback of the unified KV cache implementation is that it can end up attending to many blocks that are masked by the KQ mask (i.e. cross-sequence attention), which we basically discard. Even though we don't need them, we still perform the computation over such -INF tiles. I am thinking of a way to "filter out" the tiles that are already masked by the KQ mask so that we save some computation. I could be missing something (so feel free to ignore or correct me), but another idea is to run a quick kernel that creates a list of the tiles that have at least one non-INF cell and then, in the …, iterate only over that list.

I am not fully following the explanation about the I/O of the kernel, but my thinking is very simple - the work area of the kernel is roughly: …
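As a rough illustration of the "precompute a list of non-masked tiles" idea (a sketch with an assumed tile size, kernel name and a float mask - not llama.cpp code): a cheap kernel compacts the indices of KV tiles that contain at least one non -INF entry, and the FA kernel would then iterate only over this list.

```cuda
// Rough sketch (not llama.cpp code): build a compact list of KV tiles that are not fully masked.
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

#define TILE 32  // assumed KV tile size

// One warp per tile: append the tile index to `tile_list` if any mask value is > -INF.
__global__ void build_active_tile_list(const float * mask, int n_kv,
                                       int * tile_list, int * n_tiles) {
    const int tile = blockIdx.x;
    const int i    = tile*TILE + threadIdx.x;

    float v = (i < n_kv) ? mask[i] : -INFINITY;

    // warp max reduction
    for (int offset = 16; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset));
    }

    if (threadIdx.x == 0 && v != -INFINITY) {
        tile_list[atomicAdd(n_tiles, 1)] = tile; // compact list of tiles the FA kernel must visit
    }
}

int main() {
    const int n_kv = 8*TILE;
    std::vector<float> mask(n_kv, -INFINITY);
    for (int i = 2*TILE; i < 4*TILE; ++i) {
        mask[i] = 0.0f; // only tiles 2 and 3 are visible to this query
    }

    float * d_mask; int * d_list; int * d_n;
    cudaMalloc(&d_mask, n_kv*sizeof(float));
    cudaMalloc(&d_list, (n_kv/TILE)*sizeof(int));
    cudaMalloc(&d_n, sizeof(int));
    cudaMemcpy(d_mask, mask.data(), n_kv*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_n, 0, sizeof(int));

    build_active_tile_list<<<n_kv/TILE, TILE>>>(d_mask, n_kv, d_list, d_n);

    int n_tiles = 0;
    cudaMemcpy(&n_tiles, d_n, sizeof(int), cudaMemcpyDeviceToHost);
    std::vector<int> tiles(n_tiles);
    cudaMemcpy(tiles.data(), d_list, n_tiles*sizeof(int), cudaMemcpyDeviceToHost);

    printf("active tiles:"); // note: order is not deterministic due to atomicAdd
    for (int t : tiles) printf(" %d", t);
    printf("\n");

    cudaFree(d_mask); cudaFree(d_list); cudaFree(d_n);
    return 0;
}
```

Because the list is built with `atomicAdd`, its order is not deterministic; if the FA kernel benefits from visiting KV tiles in order, a prefix-sum based compaction would be the natural alternative.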
When I say "tile", I mean a tile of the output tensor, so VKQ in the case of FA. The standard way to parallelize a matrix multiplication is to assign some output tile to a CUDA block; the CUDA block then iterates over the input matrices to calculate said tile. The accumulators are kept in registers and you can simply write out the output tile.

The problem is that modern GPUs are becoming increasingly "wide", meaning that the number of streaming multiprocessors has become much larger than in previous generations. This causes "tail effects" where the last wave of CUDA blocks cannot fully utilize the hardware; in the worst-case scenario with 1.01 waves you basically lose half the performance. With a stream-k decomposition CUDA blocks can work on fractions of output tiles, and as a consequence the start and end points in the continuous …
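To illustrate what working on fractions of output tiles means, here is a small host-side sketch of a stream-k-style work split (purely illustrative numbers and names, not the actual CUDA/CUTLASS code): the total work is the number of output tiles times the number of K-loop iterations per tile, and each CUDA block gets a contiguous slice of that work, so a block can start or end in the middle of a tile.

```cuda
// Illustrative sketch of a stream-k style work decomposition (not the actual llama.cpp code).
#include <cstdio>

int main() {
    const int n_tiles        = 13;  // output tiles of the matrix multiplication
    const int iters_per_tile = 8;   // K-loop iterations needed to finish one output tile
    const int n_blocks       = 16;  // CUDA blocks (ideally a multiple of the SM count)

    const int total_iters = n_tiles*iters_per_tile;

    for (int b = 0; b < n_blocks; ++b) {
        // each block gets a contiguous range of K-iterations, independent of tile boundaries
        const int it0 = ( b   *total_iters)/n_blocks;
        const int it1 = ((b+1)*total_iters)/n_blocks;

        // a block may therefore start/end in the middle of a tile ("fractional tiles");
        // partial results for such tiles are combined afterwards in a fix-up step
        printf("block %2d: tile %2d iter %d  ->  tile %2d iter %d\n",
               b, it0/iters_per_tile, it0%iters_per_tile,
               (it1 - 1)/iters_per_tile, (it1 - 1)%iters_per_tile);
    }
    return 0;
}
```

Every block gets a nearly equal share of the work regardless of how many output tiles there are, which avoids the tail effect at the cost of a fix-up pass for tiles that are split across blocks.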
One of the most basic strategies for good GPU performance is to achieve a high arithmetic intensity: you want to do as much work as possible for each data value that you load from memory. So for MMQ and FA I'm trying to work with large output tiles, because the larger the output tiles are, the more work can be done per data value loaded from memory. The dimension that is the most difficult to scale for language models is … Basically the problem is that right now the CUDA code always tries to work on as many …

What I'll do is add the GQA optimization to the vec kernels. Then it should be possible to use them for GQA models on all GPUs without a performance penalty relative to the mma kernel. For the vec kernels it's always possible to evaluate sequences independently from one another, so an optimization to skip masked-out KQ slices would make sense. If the indices of the relevant KQ slices were precomputed, it would of course be a bit faster. If all sequences are disjoint from one another, then there would never be a benefit to scaling up the …
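A minimal sketch of what the GQA optimization buys a vec-style kernel (assumed sizes, names and layout; not the llama.cpp kernels): each K element is loaded from the KV cache once and then reused for every Q head that maps to the same KV head.

```cuda
// Sketch (assumed sizes and names, not llama.cpp code): with GQA, one loaded K element can be
// reused for every Q head that shares the same KV head, raising arithmetic intensity.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define D         64   // head dimension (one thread per element)
#define GQA_RATIO 4    // n_head / n_head_kv

// One block per (KV head, KV position): K is loaded once per thread and applied to GQA_RATIO Q heads.
__global__ void kq_dot_gqa(const float * K, const float * Q, float * KQ, int n_kv) {
    const int kv_head = blockIdx.x;
    const int k_pos   = blockIdx.y;

    const float k_val = K[(kv_head*n_kv + k_pos)*D + threadIdx.x]; // single load from the KV cache

    for (int g = 0; g < GQA_RATIO; ++g) {
        const int   q_head = kv_head*GQA_RATIO + g;                // Q heads mapped to this KV head
        const float prod   = k_val*Q[q_head*D + threadIdx.x];
        atomicAdd(&KQ[q_head*n_kv + k_pos], prod);                 // toy reduction over the head dimension
    }
}

int main() {
    const int n_kv = 8, n_head_kv = 2, n_head = n_head_kv*GQA_RATIO;

    std::vector<float> K(n_head_kv*n_kv*D, 1.0f), Q(n_head*D, 1.0f), KQ(n_head*n_kv, 0.0f);
    float *dK, *dQ, *dKQ;
    cudaMalloc(&dK,  K.size()*sizeof(float));
    cudaMalloc(&dQ,  Q.size()*sizeof(float));
    cudaMalloc(&dKQ, KQ.size()*sizeof(float));
    cudaMemcpy(dK, K.data(), K.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dQ, Q.data(), Q.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(dKQ, 0, KQ.size()*sizeof(float));

    kq_dot_gqa<<<dim3(n_head_kv, n_kv), D>>>(dK, dQ, dKQ, n_kv);

    cudaMemcpy(KQ.data(), dKQ, KQ.size()*sizeof(float), cudaMemcpyDeviceToHost);
    printf("KQ[0][0] = %.1f (expected %d)\n", KQ[0], D); // each dot product sums D ones

    cudaFree(dK); cudaFree(dQ); cudaFree(dKQ);
    return 0;
}
```

In a real vec kernel the reduction over the head dimension would be done with warp shuffles rather than `atomicAdd`; the sketch only shows where the data reuse comes from.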
Yes, this makes complete sense - thank you for the detailed answer. I think the reasoning is clear. I would flag one detail that IMO would warrant some extra looking into: the conclusions in #12014 about increasing the arithmetic intensity of the FA vec kernel are valid mainly for the speculative decoding use case, where the small batch of tokens is from the same sequence. However, for small batches with mixed sequences this might not be the case. The …
I agree that marking the batches would bring optimal performance in the two scenarios. However, this would come with some extra complexity in the batching logic. Not saying it's not worth it, but my suggestion would be to first explore the idea of: …
This is very easy to implement in the FA-vec kernel, even without precomputing the indices (see the diff that I suggested earlier). And even if it hurts the speculative decoding case by ~5%, I think it might significantly improve the multi-sequence use case. So it would be a simple intermediate step before we decide whether it is worth marking the multi-sequence batches.
ref #10860 #13488
This should largely resolve the text-generation performance issue for multiple sequences with large prompts on Metal. I think this practically achieves the same effect as PagedAttention.

This PR implements it for the FA-vec kernel - we simply skip fully-masked KV cache blocks with the size of the simdgroup (32). The other FA kernel, used for BS >= 4, already has a similar optimization, but I think it is not as optimal as it can be. It can now be improved by utilizing #12850 to precompute the skip conditions and pass those to the FA kernel as an extra tensor - this will be done in a follow-up PR (edit: nvm - it's already good enough).

Note that this also fixes the TG performance for SWA models, without having to do the defrag (see #13194 (comment)).
To test this, we can use the `llama-batched-bench` tool like this. It's important to generate more than one large prompt (via the `-npl` argument) in order to simulate a server with multiple slots. We observe that the TG speed is improved at large contexts.

```sh
make -j && ./bin/llama-batched-bench -m ../models/qwen2.5-7b-coder/ggml-model-q4_k.gguf -c 40000 -b 2048 -ub 512 -npp 0,512,4096,8192 -ntg 32 -npl 2,3 -fa
```
```
main: n_kv_max = 40192, n_batch = 2048, n_ubatch = 512, flash_attn = 1, is_pp_shared = 0, n_gpu_layers = -1, n_threads = 16, n_threads_batch = 16
```